
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
import matplotlib 
import os
import pandas as pd
import scipy.optimize
from numpy.random import rand,normal
import pickle
from scipy.stats import sem

def rotation_fit(t,t0,t1,A,w,phi,c):
    """Model a rotation trace: a cosine that only evolves between t0 and t1.

    For t0 <= t <= t1 the value is ``A*cos(w*(t - t0) + phi) + c``.
    Before t0 the output is held at the cosine's value at t0
    (``A*cos(phi) + c``); after t1 it is held at the cosine's value at t1.
    If both t < t0 and t > t1 hold (only possible when t0 > t1), the
    after-t1 value wins.
    """
    hold_before = A*np.cos(phi)+c
    hold_after = A*np.cos(w*(t1-t0)+phi)+c
    oscillating = A*np.cos(w*(t-t0)+phi)+c

    # Clamp the early part first, then the late part, so the late clamp
    # takes precedence where the two regions overlap.
    held_early = np.where(t<t0, hold_before, oscillating)
    return np.where(t>t1, hold_after, held_early)

def persistance_plot(t_values,data_list,axes =None,vbins = None):
    """Plot a list of traces with the same time base as a 2D histogram.

    Imitates the persistence mode of an oscilloscope. Designed to work with
    digitised data from a PicoScope, where the number of distinct sample
    values is a natural vertical bin count.

    Parameters
    ----------
    t_values : array_like
        Time base shared by every trace.
    data_list : sequence of array_like
        The traces. Each is assumed to have ``len(t_values)`` samples, since
        ``np.tile`` pairs every value with the time base.
    axes : matplotlib Axes, optional
        Axes to draw into. If omitted, a new figure with one axes is created.
    vbins : int, optional
        Number of vertical bins. If omitted (or 0), uses the number of
        distinct values in the first trace.

    Returns
    -------
    tuple
        Whatever ``Axes.hist2d`` returns: (counts, xedges, yedges, image).
    """
    if axes is None:
        axes = plt.figure().gca()
    if not vbins:
        vbins=len(np.unique(data_list[0]))
    print("Number of vertical bins",vbins)

    # Four time samples per horizontal bin, but never zero bins for very
    # short traces (hist2d rejects 0 bins).
    hbins = max(1, len(t_values) // 4)
    result = axes.hist2d(np.tile(t_values, len(data_list)), np.ravel(data_list),
                         bins=(hbins, vbins), norm=matplotlib.colors.LogNorm())
    # Attach the colorbar to this histogram's image on the chosen axes,
    # not to whatever axes happens to be current.
    plt.colorbar(result[3], ax=axes)
    return result


# t_values = np.linspace(0,1,10000)
# Vs = rotation_fit(t_values,0.5,1.8,1,1*np.pi,np.pi/4,0.2)

# persistance_data = []
# for i in range(1000):
#     t0 = 0.1#+0.1*rand(1)
#     t1 = 0.9#-0.01*rand(1)
#     A = 0.05*normal(1)+0.5
#     w = (0.7+0.05*normal(1))*2*np.pi
#     phi = (-0.1+0.05*normal(1))*2*np.pi
#     c = 0
#     # print(params)

#     Vs = rotation_fit(t_values,t0,t1,A,w,phi,c)

#     fakedata = Vs +0.05*np.random.rand(len(Vs))
#     persistance_data.append(fakedata)
    # plt.plot(t_values,fakedata)

    # try:
    #     fit,pcov = scipy.optimize.curve_fit(rotation_fit,t_values,fakedata,p0 = [0.2,0.9,0.5,4,0,0.1])
    #     print("Fit",fit,"err",np.sqrt(np.diag(pcov)))
    #     plt.plot(t_values,rotation_fit(t_values,*fit))
    #     plt.title("Fit t0={:.1f} t1={:.1f} A={:.1f} w={:.1f} phi={:.1f} c={:.1f}".format(*fit))
    # except:
    #     print("fit failed")
# plt.figure()
# plt.hist2d(np.tile(t_values, len(persistance_data)), np.ravel(persistance_data), bins=(200, int(len(t_values)/2)))#,norm= matplotlib.colors.LogNorm())
# plt.show()


# Raw PicoScope .mat traces for this measurement live in folder_path + measure.
folder_path = r"Y:\Microscope\People\adarsh\Hollow stepper\\"

measure = "one_way_overnight_mask\\"
path = folder_path+measure

V_Alist= []
V_Blist = []
i = 0
# True: parse the .mat files and refresh the pickle cache.
# False: skip parsing and reuse the cache written by an earlier run.
read = True

if read:
    # Read in data
    os.makedirs(os.path.dirname("testingdata\\"+measure),exist_ok=True)

    times = None
    # Sort the listing: os.listdir has no defined order, and the analysis
    # cells address traces by their position in V_Alist.
    for file in sorted(os.listdir(path)):
        print(file)
        if not file.endswith(".mat"):
            continue

        matfile = scipy.io.loadmat(path+file)
        print(matfile.keys())
        tstart = matfile['Tstart'][0][0]
        samples = matfile['RequestedLength'][0][0] + \
            matfile['ExtraSamples'][0][0]
        # `samples` points spaced Tinterval apart cover samples-1 intervals;
        # endpoint=False keeps the step exactly Tinterval.
        # NOTE(review): assumes Tinterval is the sample spacing — confirm.
        times = np.linspace(tstart, tstart+samples *
                            matfile['Tinterval'][0][0], samples,
                            endpoint=False)
        # NOTE(review): only the last file's time base is kept, so all files
        # are assumed to share one acquisition setting.
        #voltageB = np.array(matfile['B'])
        voltageA = np.array(matfile['A'])
        V_Alist.append(voltageA)
        #V_Blist.append(voltageB)
        i+=1

    if times is None:
        # Fail clearly instead of with a NameError when dumping `times`.
        raise FileNotFoundError("No .mat files found in " + path)

    with open("testingdata\\"+measure+"V_Adata.pkl",'wb') as f:
        pickle.dump(V_Alist,f)
    with open("testingdata\\"+measure+"V_Bdata.pkl",'wb') as f:
        pickle.dump(V_Blist,f)
    with open("testingdata\\"+measure+"times.pkl",'wb') as f:
        pickle.dump(times,f)


# Load the cached traces and time base. These pickles are written by this
# script; only unpickle files from a trusted source.
with open("testingdata\\"+measure+"V_Adata.pkl",'rb') as f:
    V_Alist = pickle.load(f)
with open("testingdata\\"+measure+"times.pkl",'rb') as f:
    times = pickle.load(f)
#with open("testingdata\\"+measure+"V_Bdata.pkl",'rb') as f:
#    V_Blist = pickle.load(f)
print(len(V_Alist))






# Figures for this measurement go under figures\<measure>.
figurepath = "figures\\" + measure
os.makedirs(os.path.dirname(figurepath), exist_ok=True)
#%%
risesA, risesB = V_Alist, V_Blist

#%%
# Stack the per-file traces into arrays, one entry per acquired trace.
risesA_array, risesB_array = np.array(risesA), np.array(risesB)
# Linear calibration of channel-A voltage into power; later plots label this
# quantity in mW. NOTE(review): gain 5.1 and offsets 0.15 / 0.010 presumably
# come from a detector calibration — confirm their source.
power_a = 5.1 * (risesA_array + 0.15) + 0.010


# angle_a = 57.3*np.arcsin(rel_a**0.5)

# Quick visual check of the first trace against the shared time base.
plt.plot(times, power_a[0])
#%%


#%%
# A quick slideshow of the traces: smooth every rep_num-th trace and record
# its starting level, peak, and noise.
import time
rep_num=8 # Analyse every rep_num-th trace. With a wp rotating 45 degrees
          # between traces, rep_num=4 picks traces every 180 degrees and
          # rep_num=8 every 360 degrees.
window=20         # raw samples averaged into one smoothed point
max_traces=2000   # analyse at most this many traces
n_traces=min(power_a.shape[0], max_traces)
# NOTE(review): np.ravel assumes each trace is one column/row of samples
# (e.g. shape (N, 1) as loaded from .mat) — confirm.
n_samples=np.ravel(power_a[0]).size
n_windows=n_samples//window
start=[]
high=[]
st_sd=[]
high_sd=[]
# Rows skipped by the rep_num stride stay zero.
smooth_power_a=np.zeros([n_traces, n_windows])
for a in range(0, n_traces, rep_num):
    trace=np.ravel(power_a[a])
    # Non-overlapping window means. The trailing partial window is dropped,
    # so every smoothed point averages exactly `window` samples.
    smooth_power_a[a,:]=trace[:n_windows*window].reshape(n_windows, window).mean(axis=1)
    #plt.plot(smooth_power_a[a,:]/np.max(smooth_power_a[a,:]),color='b')
    start.append(np.mean(trace[0:4000]))              # mean of first 4000 raw samples
    high.append(np.max(smooth_power_a[a,:]))          # peak of the smoothed trace
    st_sd.append(np.std(smooth_power_a[a,0:200]))     # spread of first 200 smoothed points
    high_sd.append(np.std(smooth_power_a[a,550:570])) # spread of smoothed points 550-569
high=np.array(high)

#%%
#Testing code
# Compare each switch's peak power with its scaled starting power.
# `start` is accumulated as a plain list, and list * float raises TypeError,
# so convert it to an array before scaling.
# NOTE(review): the 1.92 factor is an empirical scale from the legend — confirm.
plt.plot(high,'o')
plt.plot(1.92*np.asarray(start),'o')
plt.xlim(0,50)
plt.legend(['Highest power in a switch','Starting point of the switch x 1.92 (mW)'])

plt.xlabel('Switch number')
plt.ylabel('Power (mW)')

plt.show()
# Start/peak ratio for each switch.
plt.plot(np.asarray(start)/high,'o',c='darkred')
plt.xlabel('Switch number')
plt.ylabel('Relative intensity')
#%%
# Spread of the start/peak ratio. Computed directly from start and high
# because `rel` is only assigned in the following cell.
print(np.std(np.asarray(start)/high))
# NOTE(review): row 88 is normalised by high[0] (the first analysed trace),
# not by its own peak high[88 // rep_num] — confirm this is intended.
print(np.std(smooth_power_a[88,0:100]/high[0]))
plt.plot(high,start,'o',c='gold')
plt.xlabel('High')
plt.ylabel('Start')
#%%
# Start/peak ratio for each analysed switch and its summary statistics.
start=np.array(start)
rel=start/high
av=np.mean(rel)
dev=np.std(rel)
# Relative standard errors: st_sd is the spread of 200 smoothed points and
# high_sd of 20 smoothed points, each divided by its own level.
st_err=np.asarray(st_sd)/(np.sqrt(200)*start)
high_err=np.asarray(high_sd)/(np.sqrt(20)*high)
deverr=np.sqrt(np.array(st_sd)**2 + np.array(high_sd)**2)
# Relative errors of start and high added in quadrature, which gives the
# relative error of the ratio rel.
# NOTE(review): err is a *relative* uncertainty but is drawn below as an
# absolute error bar on rel; rel*err would be the absolute one — confirm.
err=np.sqrt(st_err**2 + high_err**2)
n_rel=len(rel)

plt.plot(np.asarray(high_sd)/high,'o',color='darkred',)
plt.plot(np.asarray(st_sd)/start,'s',color='darkgreen',)
plt.title('Standard deviation for each measurement')
plt.legend(['Peak value', 'Start value'])
plt.show()
# x-range and the ±3 % reference band follow the actual number of switches.
plt.errorbar(np.arange(n_rel),rel,yerr=err,c='k',fmt='o')
plt.plot([0,n_rel],[1.03*av,1.03*av],c='r')
plt.plot([0,n_rel],[.97*av,.97*av],c='r')
plt.title('Relative intensity for each measurement')
plt.show()
plt.plot(high_err,'o',color='darkred',)
plt.plot(st_err,'s',color='darkgreen',)
plt.title('Standard error for each measurement')
plt.legend(['Peak value', 'Start value'])
#plt.plot(err,'o',c='purple')

#%%
# Histogram of start/peak ratios, marking mean ± one standard deviation
# (dark green) and mean ± the average propagated error (magenta).
averr=np.mean(err)
plt.hist(rel,bins=30,color='gold')
# axvline spans the full axis height, so the markers always reach the
# tallest bin instead of stopping at a fixed count of 30.
for half_width, colour in ((dev, 'darkgreen'), (averr, 'm')):
    plt.axvline(av-half_width, color=colour)
    plt.axvline(av+half_width, color=colour)